package cloud.seri.gateway.security.oauth2

import cloud.seri.gateway.config.oauth2.OAuth2Properties
import cloud.seri.gateway.web.filter.RefreshTokenFilter
import cloud.seri.gateway.web.rest.errors.InvalidPasswordException
import io.github.jhipster.config.JHipsterProperties
import org.junit.Assert
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExpectedException
import org.junit.runner.RunWith
import org.mockito.Mock
import org.mockito.junit.MockitoJUnitRunner
import org.springframework.http.HttpEntity
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.http.ResponseEntity
import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken
import org.springframework.security.oauth2.common.OAuth2AccessToken
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException
import org.springframework.security.oauth2.provider.token.TokenStore
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.client.HttpClientErrorException
import org.springframework.web.client.RestTemplate

import javax.servlet.http.Cookie
import java.util.Date

import org.mockito.Mockito.`when`

/**
 * Tests the password and refresh token grants of [OAuth2AuthenticationService], plus the cookie
 * handling performed by [RefreshTokenFilter], against a mocked token endpoint.
 *
 * The [RestTemplate] and [TokenStore] are Mockito mocks; every test starts from the stubbing
 * installed in [init].
 *
 * @see OAuth2AuthenticationService
 */
@RunWith(MockitoJUnitRunner::class)
class OAuth2AuthenticationServiceTest {

    @Mock
    private lateinit var restTemplate: RestTemplate

    @Mock
    private lateinit var tokenStore: TokenStore

    private lateinit var authenticationService: OAuth2AuthenticationService
    private lateinit var refreshTokenFilter: RefreshTokenFilter

    @Rule
    @JvmField
    var expectedException: ExpectedException = ExpectedException.none()

    /**
     * Wires the service and filter under test to the mocked collaborators and installs the
     * token endpoint stubs shared by all tests.
     */
    @Before
    fun init() {
        val oAuth2Properties = OAuth2Properties()
        val jHipsterProperties = JHipsterProperties()
        jHipsterProperties.security.clientAuthorization.accessTokenUri = ACCESS_TOKEN_URI
        val cookieHelper = OAuth2CookieHelper(oAuth2Properties)
        val accessToken = createAccessToken(ACCESS_TOKEN_VALUE, REFRESH_TOKEN_VALUE)

        mockInvalidPassword()
        mockPasswordGrant(accessToken)
        mockRefreshGrant()

        val authorizationClient = UaaTokenEndpointClient(restTemplate, jHipsterProperties, oAuth2Properties)
        authenticationService = OAuth2AuthenticationService(authorizationClient, cookieHelper)
        `when`(tokenStore.readAccessToken(ACCESS_TOKEN_VALUE)).thenReturn(accessToken)
        refreshTokenFilter = RefreshTokenFilter(authenticationService, tokenStore)
    }

    /**
     * Builds the form-encoded password grant request for user "user" with the given [password].
     *
     * The stubs match on equality with this entity, so it has to equal the request that reaches
     * the mocked [RestTemplate], including the `Authorization` header.
     */
    private fun passwordGrantRequest(password: String): HttpEntity<MultiValueMap<String, String>> {
        val headers = HttpHeaders().apply {
            contentType = MediaType.APPLICATION_FORM_URLENCODED
            add("Authorization", CLIENT_AUTHORIZATION)
        }
        val form = LinkedMultiValueMap<String, String>().apply {
            set("username", "user")
            set("password", password)
            add("grant_type", "password")
        }
        return HttpEntity<MultiValueMap<String, String>>(form, headers)
    }

    /** Stubs the token endpoint to answer a login with password "user2" with HTTP 400. */
    private fun mockInvalidPassword() {
        `when`(restTemplate.postForEntity(ACCESS_TOKEN_URI, passwordGrantRequest("user2"), OAuth2AccessToken::class.java))
            .thenThrow(HttpClientErrorException(HttpStatus.BAD_REQUEST))
    }

    /** Stubs the token endpoint to answer a login with password "user" with [accessToken]. */
    private fun mockPasswordGrant(accessToken: OAuth2AccessToken) {
        `when`(restTemplate.postForEntity(ACCESS_TOKEN_URI, passwordGrantRequest("user"), OAuth2AccessToken::class.java))
            .thenReturn(ResponseEntity(accessToken, HttpStatus.OK))
    }

    /**
     * Stubs the token endpoint to exchange [REFRESH_TOKEN_VALUE] for a token pair made of
     * [NEW_ACCESS_TOKEN_VALUE] and [NEW_REFRESH_TOKEN_VALUE].
     */
    private fun mockRefreshGrant() {
        val form = LinkedMultiValueMap<String, String>().apply {
            add("grant_type", "refresh_token")
            add("refresh_token", REFRESH_TOKEN_VALUE)
        }
        // we must authenticate with the UAA server via HTTP basic authentication using the
        // browser's client_id with no client secret
        val headers = HttpHeaders().apply { add("Authorization", CLIENT_AUTHORIZATION) }
        val refreshRequest = HttpEntity<MultiValueMap<String, String>>(form, headers)
        val refreshedToken = createAccessToken(NEW_ACCESS_TOKEN_VALUE, NEW_REFRESH_TOKEN_VALUE)
        `when`(restTemplate.postForEntity(ACCESS_TOKEN_URI, refreshRequest, OAuth2AccessToken::class.java))
            .thenReturn(ResponseEntity(refreshedToken, HttpStatus.OK))
    }

    /** Login parameters for user "user" as passed to `authenticate`. */
    private fun loginParams(password: String, rememberMe: Boolean): MutableMap<String, String> =
        mutableMapOf(
            "username" to "user",
            "password" to password,
            "rememberMe" to rememberMe.toString()
        )

    /** Creates a GET request for `http://www.test.com` that carries the given [cookies]. */
    private fun requestWithCookies(vararg cookies: Cookie): MockHttpServletRequest =
        MockHttpServletRequest(HttpMethod.GET.name, "http://www.test.com").apply { setCookies(*cookies) }

    /**
     * Returns the response cookie called [name], failing the test with a descriptive message
     * (rather than a bare NullPointerException) when the response does not contain it.
     */
    private fun MockHttpServletResponse.requireCookie(name: String): Cookie =
        getCookie(name) ?: throw AssertionError("Expected the response to contain cookie '$name'")

    /** Asserts that the response carries cookie [name] with a max age of zero. */
    private fun assertCookieCleared(response: MockHttpServletResponse, name: String) {
        Assert.assertEquals(0L, response.requireCookie(name).maxAge.toLong())
    }

    /**
     * A login with rememberMe=true sets the access token cookie and stores the refresh token
     * under [OAuth2CookieHelper.REFRESH_TOKEN_COOKIE], flagged as remember-me.
     */
    @Test
    fun testAuthenticationCookies() {
        val request = MockHttpServletRequest().apply {
            serverName = "www.test.com"
            addHeader("Authorization", CLIENT_AUTHORIZATION)
        }
        val response = MockHttpServletResponse()

        authenticationService.authenticate(request, response, loginParams(password = "user", rememberMe = true))

        // check that cookies are set correctly
        val accessTokenCookie = response.requireCookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE)
        Assert.assertEquals(ACCESS_TOKEN_VALUE, accessTokenCookie.value)
        val refreshTokenCookie = response.requireCookie(OAuth2CookieHelper.REFRESH_TOKEN_COOKIE)
        Assert.assertEquals(REFRESH_TOKEN_VALUE, OAuth2CookieHelper.getRefreshTokenValue(refreshTokenCookie))
        Assert.assertTrue(OAuth2CookieHelper.isRememberMe(refreshTokenCookie))
    }

    /**
     * A login with rememberMe=false sets the access token cookie and stores the refresh token
     * under [OAuth2CookieHelper.SESSION_TOKEN_COOKIE], not flagged as remember-me.
     */
    @Test
    fun testAuthenticationNoRememberMe() {
        val request = MockHttpServletRequest().apply { serverName = "www.test.com" }
        val response = MockHttpServletResponse()

        authenticationService.authenticate(request, response, loginParams(password = "user", rememberMe = false))

        // check that cookies are set correctly
        val accessTokenCookie = response.requireCookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE)
        Assert.assertEquals(ACCESS_TOKEN_VALUE, accessTokenCookie.value)
        val sessionTokenCookie = response.requireCookie(OAuth2CookieHelper.SESSION_TOKEN_COOKIE)
        Assert.assertEquals(REFRESH_TOKEN_VALUE, OAuth2CookieHelper.getRefreshTokenValue(sessionTokenCookie))
        Assert.assertFalse(OAuth2CookieHelper.isRememberMe(sessionTokenCookie))
    }

    /** A login rejected by the token endpoint with HTTP 400 surfaces as [InvalidPasswordException]. */
    @Test
    fun testInvalidPassword() {
        val request = MockHttpServletRequest().apply { serverName = "www.test.com" }
        val response = MockHttpServletResponse()

        expectedException.expect(InvalidPasswordException::class.java)
        authenticationService.authenticate(request, response, loginParams(password = "user2", rememberMe = false))
    }

    /**
     * With an expiring access token and a refresh token cookie, the filter writes the new token
     * pair to the response and exposes the new access token on the returned request.
     */
    @Test
    fun testRefreshGrant() {
        val request = createMockHttpServletRequest()
        val response = MockHttpServletResponse()

        val newRequest = refreshTokenFilter.refreshTokensIfExpiring(request, response)

        val newAccessTokenCookie = response.requireCookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE)
        Assert.assertEquals(NEW_ACCESS_TOKEN_VALUE, newAccessTokenCookie.value)
        val newRefreshTokenCookie = response.requireCookie(OAuth2CookieHelper.REFRESH_TOKEN_COOKIE)
        Assert.assertEquals(NEW_REFRESH_TOKEN_VALUE, newRefreshTokenCookie.value)
        val requestAccessTokenCookie = OAuth2CookieHelper.getAccessTokenCookie(newRequest)
        Assert.assertEquals(NEW_ACCESS_TOKEN_VALUE, requestAccessTokenCookie?.value)
    }

    /**
     * With an expired session token cookie, the filter clears the token cookies in the response
     * and returns a request that no longer carries them.
     */
    @Test
    fun testSessionExpired() {
        val request = requestWithCookies(
            Cookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE, ACCESS_TOKEN_VALUE),
            Cookie(OAuth2CookieHelper.SESSION_TOKEN_COOKIE, EXPIRED_SESSION_TOKEN_VALUE)
        )
        val response = MockHttpServletResponse()

        val newRequest = refreshTokenFilter.refreshTokensIfExpiring(request, response)

        // cookies in response are deleted
        assertCookieCleared(response, OAuth2CookieHelper.ACCESS_TOKEN_COOKIE)
        assertCookieCleared(response, OAuth2CookieHelper.REFRESH_TOKEN_COOKIE)
        // request no longer contains cookies
        Assert.assertNull(OAuth2CookieHelper.getAccessTokenCookie(newRequest))
        Assert.assertNull(OAuth2CookieHelper.getRefreshTokenCookie(newRequest))
    }

    /**
     * If no refresh token is found and the access token has expired, then expect an exception.
     */
    @Test
    fun testRefreshGrantNoRefreshToken() {
        val request = requestWithCookies(Cookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE, ACCESS_TOKEN_VALUE))
        val response = MockHttpServletResponse()

        expectedException.expect(InvalidTokenException::class.java)
        refreshTokenFilter.refreshTokensIfExpiring(request, response)
    }

    /** Logging out clears both the access token and the refresh token cookie in the response. */
    @Test
    fun testLogout() {
        val request = MockHttpServletRequest().apply {
            setCookies(
                Cookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE, ACCESS_TOKEN_VALUE),
                Cookie(OAuth2CookieHelper.REFRESH_TOKEN_COOKIE, REFRESH_TOKEN_VALUE)
            )
        }
        val response = MockHttpServletResponse()

        authenticationService.logout(request, response)

        assertCookieCleared(response, OAuth2CookieHelper.ACCESS_TOKEN_COOKIE)
        assertCookieCleared(response, OAuth2CookieHelper.REFRESH_TOKEN_COOKIE)
    }

    /** Stripping tokens yields a request without the access and refresh token cookies. */
    @Test
    fun testStripTokens() {
        val request = createMockHttpServletRequest()

        val newRequest = authenticationService.stripTokens(request)

        val cookies = CookieCollection(*newRequest.cookies)
        Assert.assertFalse(cookies.contains(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE))
        Assert.assertFalse(cookies.contains(OAuth2CookieHelper.REFRESH_TOKEN_COOKIE))
    }

    companion object {

        /** Token endpoint URL configured on the client under test and matched by the stubs. */
        private const val ACCESS_TOKEN_URI = "http://uaa/oauth/token"

        const val CLIENT_AUTHORIZATION = "Basic d2ViX2FwcDpjaGFuZ2VpdA=="
        const val ACCESS_TOKEN_VALUE = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0OTQyNzI4NDQsInVzZXJfbmFtZSI6InVzZXIiLCJhdXRob3JpdGllcyI6WyJST0xFX1VTRVIiXSwianRpIjoiNzc1ZTJkYWUtYWYzZi00YTdhLWExOTktNzNiZTU1MmIxZDVkIiwiY2xpZW50X2lkIjoid2ViX2FwcCIsInNjb3BlIjpbIm9wZW5pZCJdfQ.gEK0YcX2IpkpxnkxXXHQ4I0xzTjcy7edqb89ukYE0LPe7xUcZVwkkCJF_nBxsGJh2jtA6NzNLfY5zuL6nP7uoAq3fmvsyrcyR2qPk8JuuNzGtSkICx3kPDRjAT4ST8SZdeh7XCbPVbySJ7ZmPlRWHyedzLA1wXN0NUf8yZYS4ELdUwVBYIXSjkNoKqfWm88cwuNr0g0teypjPtjDqCnXFt1pibwdfIXn479Y1neNAdvSpHcI4Ost-c7APCNxW2gqX-0BItZQearxRgKDdBQ7CGPAIky7dA0gPuKUpp_VCoqowKCXqkE9yKtRQGIISewtj2UkDRZePmzmYrUBXRzfYw"
        const val REFRESH_TOKEN_VALUE = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX25hbWUiOiJ1c2VyIiwic2NvcGUiOlsib3BlbmlkIl0sImF0aSI6Ijc3NWUyZGFlLWFmM2YtNGE3YS1hMTk5LTczYmU1NTJiMWQ1ZCIsImV4cCI6MTQ5Njg2NDc0MywiYXV0aG9yaXRpZXMiOlsiUk9MRV9VU0VSIl0sImp0aSI6IjhmYjI2YTllLTdjYzQtNDFlMi1hNzBjLTk4MDc0N2U2YWFiOSIsImNsaWVudF9pZCI6IndlYl9hcHAifQ.q1-Df9_AFO6TJNiLKV2YwTjRbnd7qcXv52skXYnog5siHYRoR6cPtm6TNQ04iDAoIHljTSTNnD6DS3bHk41mV55gsSVxGReL8VCb_R8ZmhVL4-5yr90sfms0wFp6lgD2bPmZ-TXiS2Oe9wcbNWagy5RsEplZ-sbXu3tjmDao4FN35ojPsXmUs84XnNQH3Y_-PY9GjZG0JEfLQIvE0J5BkXS18Z015GKyA6GBIoLhAGBQQYyG9m10ld_a9fD5SmCyCF72Jad_pfP1u8Z_WyvO-wrlBvm2x-zBthreVrXU5mOb9795wJEP-xaw3dXYGjht_grcW4vKUFtj61JgZk98CQ"
        const val EXPIRED_SESSION_TOKEN_VALUE = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX25hbWUiOiJ1c2VyIiwic2NvcGUiOlsib3BlbmlkIl0sImF0aSI6IjE0NTkwYzdkLTQ5M2YtNDU0NS05MzlmLTg1ODM4ZjRmNzNmNSIsImV4cCI6MTQ5NTU3Mjg5MywiaWF0IjoxNDk1MzIwODkzLCJhdXRob3JpdGllcyI6WyJST0xFX1VTRVIiXSwianRpIjoiNzVhYTIxNzEtMzFmNi00MWJmLWExZGUtYWU0YTg1ZjZiMjEyIiwiY2xpZW50X2lkIjoid2ViX2FwcCJ9.gAH-yly7WAslQUeGhyHmjYXwQN3dluvoT84iOJ2mVWYGVlnDRsoxN3_d1ozqtiso9UM7dWpAr80o3gK7AyK-cO1GGBXa3lg0ETsbucFoqHLivgGZA2qVOsFlDq8E7DZENAbOWmywmhFUOogCfZ-BqsuFSi8waMLL-1qlhehBPuK1KzGxIZbjSVUFFFYTxoWPKi2NNTBzYSwwCV0ixj-gHyFC6Gl5ByA4EvYygGUZF2pACxs4tIRkmT90pXWCjWeKS9k9MlxZ7C4UHqyTRW-IYzqAm8OHdwsnXeu0GkFYc08gxoUuPcjMby8ziYLG5uWj0Ua0msmiSjoafzs-5xfH-Q"
        const val NEW_ACCESS_TOKEN_VALUE = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0OTQyNzY2NDEsInVzZXJfbmFtZSI6InVzZXIiLCJhdXRob3JpdGllcyI6WyJST0xFX1VTRVIiXSwianRpIjoiYzIyY2YzMDgtZTIyYi00YzNjLWI5MjctOTYwYzA2YmY1ZmU0IiwiY2xpZW50X2lkIjoid2ViX2FwcCIsInNjb3BlIjpbIm9wZW5pZCJdfQ.IAhE39GCqWRUuXdWy-raOcE9NYXRhGiqkeJH649501LeqNPH5HtRUNWmudVRgwT52Bj7HcbJapMLGetKIMEASqC1-WARfcZ_PR0r7Kfg3OlFALWOH_oVT5kvi2H-QCoSAF9mRYK6abCh_tPk5KryVB5c7YxTMIXDT2nTsSexD8eNQOMBWRCg0RaLHZ9bKfeyVgncQJsu7-vTo1xJyh-keYpdNZ0TA2SjYJgezmB7gwW1Kmc7_83htr8VycG7XA_PuD9--yRNlrN0LtNHEBqNypZsOe6NvpKiNlodFYHlsU1CaumzcF9U7dpVanjIUKJ5VRWVUlSFY6JJ755W29VCTw"
        const val NEW_REFRESH_TOKEN_VALUE = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX25hbWUiOiJ1c2VyIiwic2NvcGUiOlsib3BlbmlkIl0sImF0aSI6ImMyMmNmMzA4LWUyMmItNGMzYy1iOTI3LTk2MGMwNmJmNWZlNCIsImV4cCI6MTQ5Njg2ODU4MSwiYXV0aG9yaXRpZXMiOlsiUk9MRV9VU0VSIl0sImp0aSI6ImU4YmZhZWJlLWYzMDItNGNjZS1hZGY1LWQ4MzE5OWM1MjBlOSIsImNsaWVudF9pZCI6IndlYl9hcHAifQ.OemWBUfc-2rl4t4VVqolYxul3L527PbSbX2Xvo7oyy3Vy5nmmblqp4hVGdTEjivrlldGVQX03ERbrA-oFkpmfWbBzLvnKS6AUq1MGjut6dXZJeiEqNYmiAABn6jSgK26S0k6b2ADgmf7mxJO8EBypb5sT1DMAbY5cbOe7r4ZG7zMTVSvlvjHTXp_FM8Y9i6nehLD4XDYY57cb_ZA89vAXNzvTAjoopDliExgR0bApG6nvvDEhEYgTS65lccEQocoev6bISJ3RvNYNPJxWcNPftKDp4HrEt2E2WP28K5IivRtQgDQNlQeormf1tp6AG-Oj__NXyAPM7yhAKXNy2zWdQ"

        /**
         * Creates an access token with value [accessTokenValue] whose expiration is the moment
         * of creation and which carries a refresh token with value [refreshTokenValue].
         */
        fun createAccessToken(accessTokenValue: String, refreshTokenValue: String): OAuth2AccessToken =
            DefaultOAuth2AccessToken(accessTokenValue).apply {
                expiration = Date() // token expires now
                refreshToken = DefaultOAuth2RefreshToken(refreshTokenValue)
            }

        /**
         * Creates a GET request for `http://www.test.com` carrying [ACCESS_TOKEN_VALUE] and
         * [REFRESH_TOKEN_VALUE] as the access token and refresh token cookies.
         */
        fun createMockHttpServletRequest(): MockHttpServletRequest =
            MockHttpServletRequest(HttpMethod.GET.name, "http://www.test.com").apply {
                setCookies(
                    Cookie(OAuth2CookieHelper.ACCESS_TOKEN_COOKIE, ACCESS_TOKEN_VALUE),
                    Cookie(OAuth2CookieHelper.REFRESH_TOKEN_COOKIE, REFRESH_TOKEN_VALUE)
                )
            }
    }
}
